import os
from matplotlib import pyplot as plt
# https://www.cnblogs.com/techflow/p/13818748.html
def read_losses(path):
    """Read one loss value per line from a plain-text loss log.

    Blank lines are skipped. Each remaining line is parsed with ``float``
    instead of ``eval`` so the log file cannot execute arbitrary code; a
    non-numeric line raises ``ValueError``.

    Args:
        path: Path to the text file holding one loss value per line.

    Returns:
        A list of floats in file order.
    """
    with open(path, encoding="utf-8") as f:
        return [float(line) for line in f if line.strip()]


def plot_loss_curves(train_loss, val_loss, out_path='loss_curve.svg', max_ticks=10):
    """Plot train/val loss per epoch, save the figure, then show it.

    Args:
        train_loss: Sequence of per-epoch training losses (epoch 1 first).
        val_loss: Sequence of per-epoch validation losses (epoch 1 first).
        out_path: File the figure is saved to before being shown.
        max_ticks: Approximate number of x-axis tick intervals.
    """
    plt.figure()
    # Each series gets its own epoch range, so logs of different lengths
    # (e.g. an interrupted run) still plot instead of raising in plt.plot.
    plt.plot(range(1, len(train_loss) + 1), train_loss, 'red', linewidth=2, label='train loss')
    plt.plot(range(1, len(val_loss) + 1), val_loss, 'coral', linewidth=2, label='val loss')
    plt.grid(True)
    n_epochs = max(len(train_loss), len(val_loss))
    # Ceil-divide so roughly max_ticks intervals span the run. For 100 epochs
    # this gives range(0, 101, 10), matching the previous hard-coded ticks.
    step = max(1, -(-n_epochs // max_ticks))
    plt.xticks(range(0, n_epochs + 1, step))
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig(out_path)
    plt.show()


if __name__ == "__main__":
    train_loss_path = 'logs/loss/epoch_loss_2021_11_25_11_25_36.txt'
    val_loss_path = 'logs/loss/epoch_val_loss_2021_11_25_11_25_36.txt'
    train_loss = read_losses(train_loss_path)
    val_loss = read_losses(val_loss_path)
    print(train_loss)
    print(val_loss)
    print(len(train_loss))
    plot_loss_curves(train_loss, val_loss)

